{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vPZaHYB21OQi"
      },
      "source": [
        "# Gradient Based Constrained Decoding Demo\n",
        "\n",
        "Licensed under the Apache License, Version 2.0.\n",
        "\n",
        "This method is based upon [Gradient-based Inference for Networks\n",
        "with Output Constraints](https://arxiv.org/pdf/1707.08608.pdf) by Lee et al."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JXWCFnte1Qz_"
      },
      "outputs": [],
      "source": [
        "import json\n",
        "import numpy as np\n",
        "import random\n",
        "import tensorflow as tf\n",
        "\n",
        "from psl import constrained_evaluation as eval_model  # local file import from experimental.language_structure\n",
        "from psl import data  # local file import from experimental.language_structure\n",
        "from psl import psl_model_multiwoz as model  # local file import from experimental.language_structure"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UKMZVzuP1Wm3"
      },
      "source": [
        "# Dataset and Task\n",
        "\n",
        "We study constrained decoding through the task of dialog structure prediction. Dialog structure is the high level representation of the flow of a dialog, where nodes represent abstract topics or dialog acts that statements would fit into and edges represent topic changes.\n",
        "\n",
        "To verify our method we ideally would like to test it over a multi-goal oriented dialog corpus such as [MultiWoZ 2.0](https://arxiv.org/pdf/1907.01669.pdf), created by Mihail Eric et. al. Unfortunately, this corpus does not have a ground truth dialog structure, therefore, we use a [Synthetic Multi-WoZ](https://almond-static.stanford.edu/papers/multiwoz-acl2020.pdf) dataset created by Giovanni Campagna et. al."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7q08dilwptzi"
      },
      "outputs": [],
      "source": [
        "# ========================================================================\n",
        "# Constants\n",
        "# ========================================================================\n",
        "DATA_PATH = ''\n",
        "\n",
        "RULE_WEIGHTS = np.array([1.0, 20.0, 5.0, 5.0, 5.0, 10.0, 5.0, 20.0, 5.0, 5.0, 5.0, 10.0])\n",
        "RULE_NAMES = ('rule_1', 'rule_2', 'rule_3', 'rule_4', 'rule_5', 'rule_6', 'rule_7', 'rule_8', 'rule_9', 'rule_10', 'rule_11', 'rule_12')\n",
        "\n",
        "ALPHAS = [0.1]\n",
        "GRAD_STEPS = [10, 50, 100, 500]\n",
        "LEARNING_RATES = [0.0001, 0.0005, 0.001, 0.01]\n",
        "\n",
        "# ========================================================================\n",
        "# Seed Data\n",
        "# ========================================================================\n",
        "SEED = random.randint(-10000000, 10000000)\n",
        "print(\"Seed: %d\" % SEED)\n",
        "tf.random.set_seed(SEED)\n",
        "\n",
        "# ========================================================================\n",
        "# Load Data\n",
        "# ========================================================================\n",
        "DATA = data.load_json(DATA_PATH)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "xbU8cVI6p-Hv"
      },
      "outputs": [],
      "source": [
        "#@title Config\n",
        "config = {\n",
        "    'default_seed': 4,\n",
        "    'batch_size': 128,\n",
        "    'max_dialog_size': 10,\n",
        "    'max_utterance_size': 40,\n",
        "    'class_map': {\n",
        "        'accept': 0,\n",
        "        'cancel': 1,\n",
        "        'end': 2,\n",
        "        'greet': 3,\n",
        "        'info_question': 4,\n",
        "        'init_request': 5,\n",
        "        'insist': 6,\n",
        "        'second_request': 7,\n",
        "        'slot_question': 8,\n",
        "    },\n",
        "    'accept_words': ['yes', 'great'],\n",
        "    'cancel_words': ['no'],\n",
        "    'end_words': ['thank', 'thanks'],\n",
        "    'greet_words': ['hello', 'hi'],\n",
        "    'info_question_words': ['address', 'phone'],\n",
        "    'insist_words': ['sure', 'no'],\n",
        "    'slot_question_words': ['what', '?'],\n",
        "    'includes_word': -1,\n",
        "    'excludes_word': -2,\n",
        "    'mask_index': 0,\n",
        "    'accept_index': 1,\n",
        "    'cancel_index': 2,\n",
        "    'end_index': 3,\n",
        "    'greet_index': 4,\n",
        "    'info_question_index': 5,\n",
        "    'insist_index': 6,\n",
        "    'slot_question_index': 7,\n",
        "    'utterance_mask': -1,\n",
        "    'last_utterance_mask': -2,\n",
        "    'pad_utterance_mask': -3,\n",
        "    'shuffle_train': True,\n",
        "    'shuffle_test': False,\n",
        "    'train_epochs': 1,\n",
        "}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "Bwf1tzfnqKhH"
      },
      "outputs": [],
      "source": [
        "#@title Prepare Dataset\n",
        "train_dialogs = data.add_features(\n",
        "    DATA['train_data'],\n",
        "    vocab_mapping=DATA['vocab_mapping'],\n",
        "    accept_words=config['accept_words'],\n",
        "    cancel_words=config['cancel_words'],\n",
        "    end_words=config['end_words'],\n",
        "    greet_words=config['greet_words'],\n",
        "    info_question_words=config['info_question_words'],\n",
        "    insist_words=config['insist_words'],\n",
        "    slot_question_words=config['slot_question_words'],\n",
        "    includes_word=config['includes_word'],\n",
        "    excludes_word=config['excludes_word'],\n",
        "    accept_index=config['accept_index'],\n",
        "    cancel_index=config['cancel_index'],\n",
        "    end_index=config['end_index'],\n",
        "    greet_index=config['greet_index'],\n",
        "    info_question_index=config['info_question_index'],\n",
        "    insist_index=config['insist_index'],\n",
        "    slot_question_index=config['slot_question_index'],\n",
        "    utterance_mask=config['utterance_mask'],\n",
        "    pad_utterance_mask=config['pad_utterance_mask'],\n",
        "    last_utterance_mask=config['last_utterance_mask'],\n",
        "    mask_index=config['mask_index'])\n",
        "train_data = data.pad_dialogs(train_dialogs, config['max_dialog_size'],\n",
        "                              config['max_utterance_size'])\n",
        "raw_train_labels = data.one_hot_string_encoding(DATA['train_truth_dialog'],\n",
        "                                                config['class_map'])\n",
        "train_labels = data.pad_one_hot_labels(raw_train_labels,\n",
        "                                       config['max_dialog_size'],\n",
        "                                       config['class_map'])\n",
        "train_ds = data.list_to_dataset(train_data[0], train_labels[0],\n",
        "                                config['shuffle_train'],\n",
        "                                config['batch_size'])\n",
        "\n",
        "test_dialogs = data.add_features(\n",
        "    DATA['test_data'],\n",
        "    vocab_mapping=DATA['vocab_mapping'],\n",
        "    accept_words=config['accept_words'],\n",
        "    cancel_words=config['cancel_words'],\n",
        "    end_words=config['end_words'],\n",
        "    greet_words=config['greet_words'],\n",
        "    info_question_words=config['info_question_words'],\n",
        "    insist_words=config['insist_words'],\n",
        "    slot_question_words=config['slot_question_words'],\n",
        "    includes_word=config['includes_word'],\n",
        "    excludes_word=config['excludes_word'],\n",
        "    accept_index=config['accept_index'],\n",
        "    cancel_index=config['cancel_index'],\n",
        "    end_index=config['end_index'],\n",
        "    greet_index=config['greet_index'],\n",
        "    info_question_index=config['info_question_index'],\n",
        "    insist_index=config['insist_index'],\n",
        "    slot_question_index=config['slot_question_index'],\n",
        "    utterance_mask=config['utterance_mask'],\n",
        "    pad_utterance_mask=config['pad_utterance_mask'],\n",
        "    last_utterance_mask=config['last_utterance_mask'],\n",
        "    mask_index=config['mask_index'])\n",
        "test_data = data.pad_dialogs(test_dialogs, config['max_dialog_size'],\n",
        "                             config['max_utterance_size'])\n",
        "raw_test_labels = data.one_hot_string_encoding(DATA['test_truth_dialog'],\n",
        "                                               config['class_map'])\n",
        "test_labels = data.pad_one_hot_labels(raw_test_labels,\n",
        "                                      config['max_dialog_size'],\n",
        "                                      config['class_map'])\n",
        "test_ds = data.list_to_dataset(test_data[0], test_labels[0],\n",
        "                               config['shuffle_test'],\n",
        "                               config['batch_size'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "caUermnvZbk7"
      },
      "outputs": [],
      "source": [
        "#@title Helper Functions\n",
        "def class_confusion_matrix(preds, labels, config):\n",
        "  correct = 0\n",
        "  incorrect = 0\n",
        "\n",
        "  class_map = config['class_map']\n",
        "  reverse_class_map = {v: k for k, v in class_map.items()}\n",
        "  class_confusion_matrix_dict = {key: {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0} for key, _ in class_map.items()}\n",
        "  class_confusion_matrix_dict['total'] = {'correct': 0, 'incorrect': 0}\n",
        "\n",
        "  for pred_list, label_list in zip(preds, labels):\n",
        "    for pred, label in zip(pred_list, label_list):\n",
        "      if class_map[label] == pred:\n",
        "        class_confusion_matrix_dict['total']['correct'] += 1\n",
        "        class_confusion_matrix_dict[label]['tp'] += 1\n",
        "      else:\n",
        "        class_confusion_matrix_dict['total']['incorrect'] += 1\n",
        "        class_confusion_matrix_dict[label]['fp'] += 1\n",
        "        class_confusion_matrix_dict[reverse_class_map[pred.numpy()]]['fn'] += 1\n",
        "\n",
        "      for key in class_map:\n",
        "        if key == label or key == reverse_class_map[pred.numpy()]:\n",
        "          continue\n",
        "        class_confusion_matrix_dict[reverse_class_map[pred.numpy()]]['tn'] += 1\n",
        "\n",
        "  return class_confusion_matrix_dict\n",
        "\n",
        "def precision_recall_f1(confusion_matrix):\n",
        "  if (confusion_matrix['tp'] + confusion_matrix['fp']) == 0:\n",
        "    precision = 0.0\n",
        "  else:\n",
        "    precision = confusion_matrix['tp'] / (confusion_matrix['tp'] + confusion_matrix['fp'])\n",
        "\n",
        "  if (confusion_matrix['tp'] + confusion_matrix['fn']) == 0:\n",
        "    recall = 0.0\n",
        "  else:\n",
        "    recall = confusion_matrix['tp'] / (confusion_matrix['tp'] + confusion_matrix['fn'])\n",
        "\n",
        "  if (precision + recall) == 0:\n",
        "    f1 = 0.0\n",
        "  else:\n",
        "    f1 = 2.0 * (precision * recall / (precision + recall))\n",
        "\n",
        "  return precision, recall, f1\n",
        "\n",
        "def print_metrics(confusion_matrix):\n",
        "  cat_accuracy = confusion_matrix['total']['correct'] / (confusion_matrix['total']['incorrect'] + confusion_matrix['total']['correct'])\n",
        "  print(\"Categorical Accuracy: %0.4f\" % (cat_accuracy,))\n",
        "  values = []\n",
        "  for key, value in confusion_matrix.items():\n",
        "    if key == 'total':\n",
        "      continue\n",
        "    precision, recall, f1 = precision_recall_f1(value)\n",
        "\n",
        "    print(\"Class: %s Precision: %0.4f  Recall: %0.4f  F1: %0.4f\" % (key.ljust(15), precision, recall, f1))\n",
        "    values.append(str(precision) + \",\" + str(recall) + \",\" + str(f1))\n",
        "  return values, cat_accuracy"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QLCc30du1huD"
      },
      "source": [
        "# Neural Model\n",
        "\n",
        "Below is a simple neural model for supervised structure prediction."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "zwRtAijRqSRu"
      },
      "outputs": [],
      "source": [
        "#@title Create Neural Model\n",
        "def build_model(input_size, learning_rate=0.001):\n",
        "  \"\"\"Build simple neural model for class prediction.\"\"\"\n",
        "  input_layer = tf.keras.layers.Input(input_size)\n",
        "  hidden_layer_1 = tf.keras.layers.Dense(1024)(input_layer)\n",
        "  hidden_layer_2 = tf.keras.layers.Dense(\n",
        "      512, activation='sigmoid')(\n",
        "          hidden_layer_1)\n",
        "  output = tf.keras.layers.Dense(\n",
        "      9, activation='softmax',\n",
        "      kernel_regularizer=tf.keras.regularizers.l2(1.0))(\n",
        "          hidden_layer_2)\n",
        "\n",
        "  model = tf.keras.Model(input_layer, output)\n",
        "\n",
        "  model.compile(\n",
        "      optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),\n",
        "      loss='categorical_crossentropy',\n",
        "      metrics=['accuracy'])\n",
        "\n",
        "  return model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "e9SzqHwKrdIF"
      },
      "outputs": [],
      "source": [
        "def run_non_constrained(train_ds, test_ds, test_labels, config, learning_rate):\n",
        "  test_model = build_model([config['max_dialog_size'], config['max_utterance_size']], learning_rate=learning_rate)\n",
        "  test_model.fit(train_ds, epochs=config['train_epochs'])\n",
        "\n",
        "  logits = test_model.predict(test_ds)\n",
        "  predictions = tf.math.argmax(logits, axis=-1)\n",
        "\n",
        "  confusion_matrix = class_confusion_matrix(predictions, test_labels, config)\n",
        "  metrics, cat_accuracy = print_metrics(confusion_matrix)\n",
        "\n",
        "  return test_model, metrics, cat_accuracy\n",
        "\n",
        "test_model, metrics, cat_accuracy = run_non_constrained(train_ds, test_ds, DATA['train_truth_dialog'], config, 0.0001)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UMQtBEO41eyh"
      },
      "source": [
        "# Gradient Based Constraint Decoding\n",
        "\n",
        "Rules:\n",
        "\n",
        "1. !FirstStatement(S) -\u003e !State(S, 'greet')\n",
        "2. FirstStatement(S) \u0026 HasGreetWord(S) -\u003e State(S, 'greet')\n",
        "3. FirstStatement(S) \u0026 !HasGreetWord(S) -\u003e State(S, 'init_request')\n",
        "4. PreviousStatement(S1, S2) \u0026 State(S2, 'init_request') -\u003e State(S1, 'second_request')\n",
        "5. PreviousStatement(S1, S2) \u0026 !State(S2, 'greet') -\u003e !State(S1, 'init_request')\n",
        "6. PreviousStatement(S1, S2) \u0026 State(S2, 'greet') -\u003e State(S1, 'init_request')\n",
        "7. LastStatement(S) \u0026 HasEndWord(S) -\u003e State(S, 'end')\n",
        "8. LastStatement(S) \u0026 HasAcceptWord(S) -\u003e State(S, 'accept')\n",
        "9. NextStatement(S1, S2) \u0026 State(S2, 'end') \u0026 HasCancelWord(S1) -\u003e State(S1, 'cancel')\n",
        "10. PreviousStatement(S1, S2) \u0026 State(S2, 'second_request') \u0026 HasInfoQuestionWord(S1) -\u003e State(S1, 'info_question')\n",
        "11. LastStatement(S) \u0026 HasInsistWord(S) -\u003e State(S, 'insist')\n",
        "12. PreviousStatement(S1, S2) \u0026 State(S2, 'second_request') \u0026 HasSlotQuestionWord(S1) -\u003e State(S1, 'slot_question')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oCj-VXXNsTX2"
      },
      "outputs": [],
      "source": [
        "def run_constrained(test_model, rule_weights, rule_names, test_ds, test_labels, config, alpha, grad_step):\n",
        "  psl_constraints = model.PSLModelMultiWoZ(rule_weights, rule_names, config=config)\n",
        "  logits = eval_model.evaluate_constrained_model(test_model, test_ds, psl_constraints, grad_steps=grad_step, alpha=alpha)\n",
        "  predictions = tf.math.argmax(tf.concat(logits, axis=0), axis=-1)\n",
        "\n",
        "  confusion_matrix = class_confusion_matrix(predictions, test_labels, config)\n",
        "  metrics, cat_accuracy = print_metrics(confusion_matrix)\n",
        "\n",
        "  return predictions, metrics, cat_accuracy\n",
        "\n",
        "predictions, metrics, cat_accuracy = run_constrained(test_model, RULE_WEIGHTS, RULE_NAMES, test_ds, DATA['test_truth_dialog'], config, 0.1, 500)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gI4uztt5Wht_"
      },
      "source": [
        "# Qualitative Analysis"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TdfArRMmuwlZ"
      },
      "outputs": [],
      "source": [
        "def recover_utterances(dialog, vocab_map):\n",
        "  sentences = []\n",
        "  for utterance in dialog:\n",
        "    sentence = ''\n",
        "    for word in utterance:\n",
        "      if word in [0, -1, -2, -3]:\n",
        "        continue\n",
        "      sentence += ' ' + vocab_map[word]\n",
        "    if sentence != '':\n",
        "      sentences.append(sentence)\n",
        "  return sentences\n",
        "\n",
        "def print_dialog(dialog_index, vocab_map, class_map, data, predictions):\n",
        "  vocab_map = {v: k for k, v in vocab_map.items()}\n",
        "  class_map = {v: k for k, v in class_map.items()}\n",
        "  utterances = recover_utterances(test_data[0][dialog_index], vocab_map)\n",
        "\n",
        "  for utterance_index in range(len(utterances)):\n",
        "    key = predictions[dialog_index][utterance_index]\n",
        "    print(\"Prediction: %s Utterance: %s\" % (class_map[int(key)].ljust(15), utterances[utterance_index]))\n",
        "\n",
        "print(\"\\nDialog Greet\")\n",
        "print('-' * 50)\n",
        "print_dialog(27, DATA['vocab_mapping'], config['class_map'], test_data, predictions)\n",
        "print(\"\\nDialog End\")\n",
        "print('-' * 50)\n",
        "print_dialog(6, DATA['vocab_mapping'], config['class_map'], test_data, predictions)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DbdDX4atVdID"
      },
      "source": [
        "# Run Hyperparameter Grid"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Vq80pblMIhci"
      },
      "outputs": [],
      "source": [
        "def run_grid(train_ds, test_ds, test_data, test_labels, rule_weights, rule_names, vocab_mapping, config, alphas, grad_steps, learning_rates):\n",
        "  character_size = 80\n",
        "\n",
        "  constrained_metrics = []\n",
        "  non_constrained_metrics = []\n",
        "  constrained_cat_accuracies = []\n",
        "  non_constrained_cat_accuracies = []\n",
        "\n",
        "  for alpha in alphas:\n",
        "    for grad_step in grad_steps:\n",
        "      for learning_rate in learning_rates:\n",
        "        print('\\n' + '=' * character_size)\n",
        "        print(\"Running: Alpha - %0.5f   Gradient Steps - %d   Learning Rate - %0.5f\" % (alpha, grad_step, learning_rate))\n",
        "        print('=' * character_size)\n",
        "\n",
        "        print('\\nNon-Constrained')\n",
        "        print('-' * character_size)\n",
        "        test_model, metrics, cat_accuracy = run_non_constrained(train_ds, test_ds, DATA['test_truth_dialog'], config, learning_rate=learning_rate)\n",
        "        non_constrained_metrics.append(metrics)\n",
        "        non_constrained_cat_accuracies.append(cat_accuracy)\n",
        "\n",
        "        print('\\nConstrained')\n",
        "        print('-' * character_size)\n",
        "        predictions, metrics, cat_accuracy = run_constrained(test_model, rule_weights, rule_names, test_ds, DATA['test_truth_dialog'], config, alpha=alpha, grad_step=grad_step)\n",
        "        constrained_metrics.append(metrics)\n",
        "        constrained_cat_accuracies.append(cat_accuracy)\n",
        "\n",
        "        print(\"\\nDialog Greet\")\n",
        "        print('-' * 50)\n",
        "        print_dialog(11, DATA['vocab_mapping'], config['class_map'], test_data, predictions)\n",
        "        print(\"\\nDialog End\")\n",
        "        print('-' * 50)\n",
        "        print_dialog(6, DATA['vocab_mapping'], config['class_map'], test_data, predictions)\n",
        "\n",
        "  return non_constrained_metrics, constrained_metrics, non_constrained_cat_accuracies, constrained_cat_accuracies"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "name": "gradient_based_constraint_decoding_demo.ipynb",
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1xEyh8B35U7Zhe11oYlIOSo33vSPlxxuA",
          "timestamp": 1633637977853
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
